#  Copyright (c) ZenML GmbH 2020. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.
"""Logger implementation."""

import builtins
import logging
import os
import re
import sys
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
    from zenml.logging.step_logging import ArtifactStoreHandler

from rich.traceback import install as rich_tb_install

from zenml.constants import (
    ENABLE_RICH_TRACEBACK,
    ENV_ZENML_CAPTURE_PRINTS,
    ENV_ZENML_LOGGING_COLORS_DISABLED,
    ENV_ZENML_LOGGING_FORMAT,
    ENV_ZENML_SUPPRESS_LOGS,
    ZENML_LOGGING_VERBOSITY,
    ZENML_STORAGE_LOGGING_VERBOSITY,
    handle_bool_env_var,
)
from zenml.enums import LoggingLevels
from zenml.utils.context_utils import ContextVarList

# When true, the custom formatter emits plain text without ANSI colors.
ZENML_LOGGING_COLORS_DISABLED = handle_bool_env_var(
    ENV_ZENML_LOGGING_COLORS_DISABLED, default=False
)


# Context-local switch: when True, console output is prefixed with the name
# of the currently running step (see `_add_step_name_to_message`).
step_names_in_console: ContextVar[bool] = ContextVar(
    "step_names_in_console",
    default=False,
)
# Context-local collection of handlers that additionally receive captured
# `print` output (see `wrapped_print`).
logging_handlers: ContextVarList["ArtifactStoreHandler"] = ContextVarList(
    "logging_handlers"
)


def _add_step_name_to_message(message: str) -> str:
    """Prefix a console message with the name of the currently running step.

    The prefix is only added when step names are enabled for console output,
    a step context is available, and the message is neither empty nor a bare
    newline. Messages containing carriage returns (e.g. progress bar updates)
    get the prefix injected after every ``\\r`` instead of at the start.
    Any failure while resolving the step context leaves the message unchanged.

    Args:
        message: The message to add the step name to.

    Returns:
        The message with the step name added (or the unchanged message).
    """
    try:
        if not step_names_in_console.get():
            return message

        from zenml.steps import get_step_context

        context = get_step_context()
        if not context or message in ("\n", ""):
            return message

        prefix = f"[{context.step_name}] "
        if "\r" not in message:
            return prefix + message
        # Progress bars redraw the line after "\r"; keep the prefix visible.
        return message.replace("\r", "\r" + prefix)
    except Exception:
        # Best effort: without a usable step context, keep the message as-is.
        return message


class CustomFormatter(logging.Formatter):
    """Formats logs according to custom specifications.

    At DEBUG verbosity each record is rendered with timestamp, logger name,
    level and source location; at every other verbosity only the message is
    shown. When console step names are enabled, the current step name is
    prefixed to the message. Unless ``ZENML_LOGGING_COLORS_DISABLED`` is set,
    the line is colored by level, text between backticks is shown in purple
    (backticks removed) and URLs are shown in blue.
    """

    grey: str = "\x1b[90m"
    white: str = "\x1b[37m"
    pink: str = "\x1b[35m"
    green: str = "\x1b[32m"
    yellow: str = "\x1b[33m"
    red: str = "\x1b[31m"
    cyan: str = "\x1b[1;36m"
    bold_red: str = "\x1b[31;1m"
    purple: str = "\x1b[38;5;105m"
    blue: str = "\x1b[34m"
    reset: str = "\x1b[0m"

    COLORS: Dict[LoggingLevels, str] = {
        LoggingLevels.DEBUG: grey,
        LoggingLevels.INFO: white,
        LoggingLevels.WARN: yellow,
        LoggingLevels.ERROR: red,
        LoggingLevels.CRITICAL: bold_red,
    }

    # Highlight patterns, compiled once instead of on every `format` call.
    _QUOTED_PATTERN = re.compile(r"`([^`]*)`")
    _URL_PATTERN = re.compile(
        r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
    )

    def _get_format_template(self, record: logging.LogRecord) -> str:
        """Get the format template based on the logging level.

        Args:
            record: The log record to format (currently unused; the template
                depends only on the configured verbosity).

        Returns:
            The format template string.
        """
        # Only include location info for DEBUG level
        if get_logging_level() == LoggingLevels.DEBUG:
            return "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
        else:
            return "%(message)s"

    def _get_level_color(self, levelno: int) -> str:
        """Get the color escape code for a numeric logging level.

        Args:
            levelno: The numeric level of the record.

        Returns:
            The level's color, or an empty string when the level is not a
            ``LoggingLevels`` member or has no color assigned (e.g. custom
            levels or NOTSET), so such records are still rendered instead of
            raising inside the handler.
        """
        try:
            level = LoggingLevels(levelno)
        except ValueError:
            return ""
        return self.COLORS.get(level, "")

    def format(self, record: logging.LogRecord) -> str:
        """Converts a log record to a (colored) string.

        Args:
            record: LogRecord generated by the code.

        Returns:
            A string formatted according to specifications.
        """
        format_template = self._get_format_template(record)

        # Apply step name prepending if enabled (for console display)
        message = record.getMessage()
        try:
            if step_names_in_console.get():
                message = _add_step_name_to_message(message)
        except Exception:
            # If we can't get step context, just use the original message
            pass

        # Work on a copy so the record shared with other handlers is not
        # mutated. Copying (instead of constructing a fresh LogRecord) keeps
        # the original creation time (for %(asctime)s), funcName, stack_info,
        # cached exc_text and any `extra` attributes. Only the message is
        # replaced; it is already interpolated, hence the empty args.
        modified_record = logging.makeLogRecord(record.__dict__)
        modified_record.msg = message
        modified_record.args = ()

        if ZENML_LOGGING_COLORS_DISABLED:
            # If color formatting is disabled, use the default format without colors
            return logging.Formatter(format_template).format(modified_record)

        level_color = self._get_level_color(record.levelno)
        formatter = logging.Formatter(
            level_color + format_template + self.reset
        )
        formatted_message = formatter.format(modified_record)

        # Single-pass substitutions: repeated or overlapping matches (e.g. a
        # URL that is a prefix of another URL) are each highlighted exactly
        # once. After each highlight the line switches back to the level color.
        formatted_message = self._QUOTED_PATTERN.sub(
            lambda match: self.reset
            + self.purple
            + match.group(1)
            + level_color,
            formatted_message,
        )
        return self._URL_PATTERN.sub(
            lambda match: self.reset + self.blue + match.group(0) + level_color,
            formatted_message,
        )


def get_logging_level() -> LoggingLevels:
    """Resolve the console logging level from ``ZENML_LOGGING_VERBOSITY``.

    The configured value is upper-cased and looked up among the
    ``LoggingLevels`` member names.

    Returns:
        The logging level.

    Raises:
        KeyError: If the logging level is not found.
    """
    level_name = ZENML_LOGGING_VERBOSITY.upper()
    known_levels = LoggingLevels.__members__
    if level_name in known_levels:
        return LoggingLevels[level_name]
    raise KeyError(f"Verbosity must be one of {list(known_levels.keys())}")


def get_storage_log_level() -> LoggingLevels:
    """Get the storage logging level from the env variable.

    The value of ``ZENML_STORAGE_LOGGING_VERBOSITY`` is upper-cased and
    looked up among the ``LoggingLevels`` member names. There is no silent
    fallback: an unknown name raises instead of defaulting to some level.

    Returns:
        The storage logging level.

    Raises:
        KeyError: If the storage logging level is not found.
    """
    verbosity = ZENML_STORAGE_LOGGING_VERBOSITY.upper()
    if verbosity not in LoggingLevels.__members__:
        raise KeyError(
            f"Verbosity must be one of {list(LoggingLevels.__members__.keys())}"
        )
    return LoggingLevels[verbosity]


def set_root_verbosity() -> None:
    """Apply the configured verbosity to the root logger.

    A NOTSET verbosity disables logging entirely. Any other level is set on
    the root logger and, if rich tracebacks are enabled, installs them
    (showing local variables only at DEBUG).
    """
    level = get_logging_level()

    if level == LoggingLevels.NOTSET:
        logging.disable(sys.maxsize)
        logging.getLogger().disabled = True
        get_logger(__name__).debug("Logging NOTSET")
        return

    if ENABLE_RICH_TRACEBACK:
        show_locals = level == LoggingLevels.DEBUG
        rich_tb_install(show_locals=show_locals)

    logging.root.setLevel(level=level.value)
    level_name = logging.getLevelName(level.value)
    get_logger(__name__).debug(f"Logging set to level: {level_name}")


def wrapped_print(*args: Any, **kwargs: Any) -> None:
    """Wrapped print function.

    Prints aimed at ``sys.stdout``/``sys.stderr`` are forwarded as log
    records to the context-local ``logging_handlers`` (INFO for stdout,
    ERROR for stderr) and then printed once to the console, optionally
    prefixed with the current step name. Prints to any other stream are
    passed straight through to the original ``print``.

    Args:
        *args: Arguments to print
        **kwargs: Keyword arguments for print
    """
    original_print = getattr(builtins, "_zenml_original_print")

    # `print(..., file=None)` writes to sys.stdout, so treat it the same way.
    file_arg = kwargs.get("file")
    if file_arg is None:
        file_arg = sys.stdout

    # IMPORTANT: Don't intercept internal calls to any objects
    # other than sys.stdout and sys.stderr. This is especially
    # critical for handling tracebacks. The default logging
    # formatter uses StringIO to format tracebacks, we don't
    # want to intercept it and create a LogRecord about it.
    if file_arg not in (sys.stdout, sys.stderr):
        original_print(*args, **kwargs)
        return

    # Build exactly the text `print` would emit (sep=None means one space).
    sep = kwargs.get("sep")
    if sep is None:
        sep = " "
    elif not isinstance(sep, str):
        # Let the real print raise its usual TypeError for an invalid `sep`.
        original_print(*args, **kwargs)
        return
    message = sep.join(str(arg) for arg in args)

    # Call active handlers first (for storage)
    if message.strip():
        level = logging.ERROR if file_arg is sys.stderr else logging.INFO
        for handler in logging_handlers.get():
            try:
                # A fresh record per handler, so one handler cannot affect
                # what the next one receives.
                record = logging.LogRecord(
                    name="print",
                    level=level,
                    pathname="",
                    lineno=0,
                    msg=message,
                    args=(),
                    exc_info=None,
                )
                # Check if handler's level would accept this record
                if record.levelno >= handler.level:
                    handler.emit(record)
            except Exception:
                # Don't let handler errors break print
                pass

    if step_names_in_console.get():
        message = _add_step_name_to_message(message)

    # `message` already contains every argument joined by `sep`; print it
    # once (passing the remaining args again would duplicate them).
    original_print(message, **kwargs)


def setup_global_print_wrapping() -> None:
    """Set up global print() wrapping with context-aware handlers.

    Replaces ``builtins.print`` with ``wrapped_print`` and keeps the real
    implementation in ``builtins._zenml_original_print``. Does nothing when
    print capturing is disabled via ``ENV_ZENML_CAPTURE_PRINTS`` or when the
    wrapper is already installed, so repeated calls are safe.
    """
    capture_prints = handle_bool_env_var(
        ENV_ZENML_CAPTURE_PRINTS, default=True
    )

    # Check the `builtins` module itself. Inside an imported module,
    # `__builtins__` is usually a plain dict (a CPython implementation
    # detail), so `hasattr(__builtins__, ...)` never sees the marker. A
    # second call would then store `wrapped_print` as the "original" print,
    # and every print() would recurse forever.
    if not capture_prints or hasattr(builtins, "_zenml_original_print"):
        return

    # Store original and replace print
    setattr(builtins, "_zenml_original_print", builtins.print)
    setattr(builtins, "print", wrapped_print)


def get_formatter() -> logging.Formatter:
    """Build the formatter used for console logging.

    A format string supplied through the ``ENV_ZENML_LOGGING_FORMAT``
    environment variable takes precedence. When it is unset or empty, the
    colored ``CustomFormatter`` is used instead.

    Returns:
        The formatter.
    """
    user_format = os.environ.get(ENV_ZENML_LOGGING_FORMAT)
    if not user_format:
        return CustomFormatter()
    return logging.Formatter(fmt=user_format)


def get_console_handler() -> Any:
    """Create a stdout stream handler configured for console output.

    The handler uses the formatter from ``get_formatter`` and is filtered at
    the console verbosity returned by ``get_logging_level``.

    Returns:
        A console handler.
    """
    handler = logging.StreamHandler(sys.stdout)
    formatter = get_formatter()
    handler.setFormatter(formatter)
    console_level = get_logging_level().value
    handler.setLevel(console_level)
    return handler


def get_logger(logger_name: str) -> logging.Logger:
    """Return the standard-library logger registered under a given name.

    Args:
        logger_name: Name of logger to initialize.

    Returns:
        A logger object.
    """
    named_logger = logging.getLogger(logger_name)
    return named_logger


def init_logging() -> None:
    """Initialize logging with default levels.

    Sets the root logger verbosity and attaches a stdout console handler to
    the root logger unless one is already present. It also routes
    ``warnings`` through logging and sets up global print wrapping. Unless
    turned off via ``ENV_ZENML_SUPPRESS_LOGS`` (on by default), it quiets a
    fixed set of noisy third-party loggers.
    """
    # Mute tensorflow cuda warnings
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    set_root_verbosity()

    # Check if a stdout console handler already exists to avoid duplicates
    # when this function runs more than once.
    root_logger = logging.getLogger()
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and handler.stream == sys.stdout
        for handler in root_logger.handlers
    )

    # Route `warnings` module output through the logging system.
    logging.captureWarnings(True)

    if not has_console_handler:
        # Reuse the shared factory so this handler is configured exactly
        # like any other console handler (stdout, formatter, console level).
        root_logger.addHandler(get_console_handler())

    # Initialize global print wrapping
    setup_global_print_wrapping()

    # Third-party logger suppression is enabled unless
    # ENV_ZENML_SUPPRESS_LOGS is explicitly set to a false value.
    suppress_zenml_logs: bool = handle_bool_env_var(
        ENV_ZENML_SUPPRESS_LOGS, True
    )
    if not suppress_zenml_logs:
        return

    # Only let WARNING and above through from these chatty libraries.
    suppressed_logger_names = [
        "urllib3",
        "azure.core.pipeline.policies.http_logging_policy",
        "grpc",
        "requests",
        "kfp",
        "tensorflow",
    ]
    for logger_name in suppressed_logger_names:
        logging.getLogger(logger_name).setLevel(logging.WARNING)

    # Silence these loggers completely.
    disabled_logger_names = [
        "rdbms_metadata_access_object",
        "backoff",
        "segment",
    ]
    for logger_name in disabled_logger_names:
        disabled_logger = logging.getLogger(logger_name)
        disabled_logger.setLevel(logging.WARNING)
        disabled_logger.disabled = True
